﻿using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.IO;
using System.Linq;
using HuginApiAddonsCS.Utils;
using System.Threading;

namespace HuginApiAddonsCS.Policy.Learning
{
    /// <summary>
    /// Learns policy trees (roots of <see cref="PolicyNode"/> trees) from CSV files of recorded
    /// action/observation rows, and fills in missing nodes of partially learned trees.
    /// </summary>
    public class PolicyTreeLearner
    {
        /// <summary>
        /// Format provider used for values parsed from or written to CSV files. Using the invariant
        /// culture keeps e.g. 0.5 from being written as "0,5" (which splits a CSV field) on
        /// comma-decimal locales.
        /// </summary>
        private static readonly IFormatProvider CsvFormat = System.Globalization.CultureInfo.InvariantCulture;

        /// <summary>
        /// Learns policy trees from data with a known time horizon.
        /// Rows are grouped into branches: a new branch starts whenever a row's time step is lower
        /// than the previous row's. Each branch is converted to a history and merged into the trees.
        /// </summary>
        /// <param name="inputFile">input csv file</param>
        /// <param name="obsCol">observation column heading</param>
        /// <param name="actionCol">action column heading</param>
        /// <param name="modelCol">model column heading</param>
        /// <param name="timeStepCol">time step column heading</param>
        /// <param name="observations">list of potential observations, passed to each new tree</param>
        /// <returns>list of policy trees</returns>
        public List<PolicyNode> LearnFromKnownHorizon(string inputFile, string obsCol, string actionCol, string modelCol, string timeStepCol, List<string> observations)
        {
            //Load CSV file
            CsvParser parser = new CsvParser();
            List<string> headings = parser.GetColumnHeadings(inputFile);
            List<List<string>> rows = parser.ParseRows(inputFile, true);

            int obsIndex = headings.IndexOf(obsCol);
            int actionIndex = headings.IndexOf(actionCol);
            int modelIndex = headings.IndexOf(modelCol);
            int timeStepIndex = headings.IndexOf(timeStepCol);

            List<PolicyNode> rootNodes = new List<PolicyNode>();
            List<List<string>> currentBranch = new List<List<string>>();
            int lastTimeStep = 0;

            foreach (List<string> row in rows)
            {
                int timeStep = Convert.ToInt32(row[timeStepIndex], CsvFormat);

                // A drop in the time step marks the start of a new branch, so flush the current one.
                if (timeStep < lastTimeStep)
                {
                    rootNodes = AddBranchToPolicyTrees(rootNodes, currentBranch, obsIndex, actionIndex, modelIndex, observations);
                    currentBranch.Clear();
                }

                lastTimeStep = timeStep;
                currentBranch.Add(row);
            }

            // The last branch in the file is never followed by a time-step drop, so flush it here;
            // otherwise the final episode would be silently discarded.
            if (currentBranch.Count > 0)
            {
                rootNodes = AddBranchToPolicyTrees(rootNodes, currentBranch, obsIndex, actionIndex, modelIndex, observations);
            }

            return rootNodes;
        }

        /// <summary>
        /// Converts a branch of rows to a history and merges it into the first compatible existing
        /// tree, or adds it as a new tree if none is compatible.
        /// </summary>
        /// <param name="trees">list of existing trees; modified in place</param>
        /// <param name="branch">branch (consecutive rows) to add into trees</param>
        /// <param name="obsIndex">observation column index</param>
        /// <param name="actionIndex">action column index</param>
        /// <param name="modelIndex">model column index; the model name is read from the first row</param>
        /// <param name="observations">list of potential observations</param>
        /// <returns>the same <paramref name="trees"/> list with the branch added in the appropriate place</returns>
        public List<PolicyNode> AddBranchToPolicyTrees(List<PolicyNode> trees, List<List<string>> branch, int obsIndex, 
            int actionIndex, int modelIndex, List<string> observations)
        {
            List<string> history = BranchToHistory(branch, obsIndex, actionIndex);

            // A well-formed history (action, (observation, action)*) always has odd length.
            // This also rejects an empty branch before branch[0] is read below.
            if ((history.Count % 2) != 1)
            {
                return trees;
            }

            string modelName = branch[0][modelIndex];
            return AddBranchToPolicyTrees(trees, history, observations, modelName);
        }

        /// <summary>
        /// Merges a history into the first existing tree that is compatible with it (per
        /// <c>PolicyNode.IsHistoryCompatible</c>), or adds it as a new tree if none is.
        /// Histories of even length are ignored.
        /// </summary>
        /// <param name="trees">trees to merge with; modified in place</param>
        /// <param name="history">history to compare and merge into trees</param>
        /// <param name="observations">list of possible observations</param>
        /// <param name="modelName">model name used for comparison</param>
        /// <returns>the same <paramref name="trees"/> list</returns>
        public List<PolicyNode> AddBranchToPolicyTrees(List<PolicyNode> trees, List<string> history, List<string> observations, string modelName)
        {
            if ((history.Count % 2) != 1)
            {
                //Error if history not an odd number
                return trees;
            }

            // Each PolicyNode call receives its own copy of the history.
            // NOTE(review): presumably because PolicyNode consumes/modifies the list it is given — confirm.
            for (int i = 0; i < trees.Count; i++)
            {
                if (trees[i].IsHistoryCompatible(history.ToList(), modelName))
                {
                    trees[i].MergeWithHistory(history.ToList(), modelName);
                    return trees;
                }
            }

            trees.Add(new PolicyNode(observations, history.ToList(), modelName));
            return trees;
        }

        /// <summary>
        /// Converts a branch to a flat history for later merging:
        /// [action_0, obs_1, action_1, ..., obs_n, action_n]. The first row contributes only
        /// its action; every later row contributes its observation followed by its action.
        /// </summary>
        /// <param name="branch">branch/rows to be converted</param>
        /// <param name="obsIndex">observation column index</param>
        /// <param name="actionIndex">action column index</param>
        /// <returns>history (empty for an empty branch)</returns>
        public List<string> BranchToHistory(List<List<string>> branch, int obsIndex, int actionIndex)
        {
            List<string> history = new List<string>();

            for (int i = 0; i < branch.Count; i++)
            {
                List<string> row = branch[i];

                // Position-based check: a reference comparison against the first row would wrongly
                // drop observations if the same row instance appeared more than once.
                if (i > 0)
                {
                    history.Add(row[obsIndex]);
                }

                history.Add(row[actionIndex]);
            }

            return history;
        }

        /// <summary>
        /// Fills missing nodes using the Hoeffding test.
        /// For every level of every tree, each incomplete node (per <c>IsComplete()</c>) is tested
        /// against each complete node on the same level and merged with every one that passes.
        /// </summary>
        /// <param name="trees">input trees; modified in place</param>
        /// <param name="epsilon">epsilon compare value</param>
        /// <param name="mergesFile">file to write merge statistics to; null or empty to skip</param>
        /// <returns>filled in trees</returns>
        public List<PolicyNode> FillMissingNodesHoeffding(List<PolicyNode> trees, double epsilon, string mergesFile)
        {
            int incompleteNodes;
            int hoeffdingTests;
            int mergeCount;
            MergeSimilarNodes(trees, (incomplete, complete) => incomplete.HoeffdingTest(complete, epsilon),
                out incompleteNodes, out hoeffdingTests, out mergeCount);

            if (!string.IsNullOrEmpty(mergesFile))
            {
                using (StreamWriter writer = new StreamWriter(mergesFile))
                {
                    writer.WriteLine("NumberOfTrees,IncompleteNodes,HoeffdingTests,Merges,IncompleteNodes/Tree,Merges/Tree,");
                    writer.WriteLine(FormatMergeRow(trees.Count, incompleteNodes, hoeffdingTests, mergeCount) + ",");
                }
            }

            Console.WriteLine("Hoeffding Tests : " + hoeffdingTests);
            Console.WriteLine("Merge Count : " + mergeCount);
            
            return trees;
        }

        /// <summary>
        /// Fills missing nodes using the Hoeffding test without writing a merges file.
        /// </summary>
        /// <param name="trees">input trees; modified in place</param>
        /// <param name="epsilon">epsilon compare value</param>
        /// <returns>filled in trees</returns>
        public List<PolicyNode> FillMissingNodesHoeffding(List<PolicyNode> trees, double epsilon)
        {
            return FillMissingNodesHoeffding(trees, epsilon, string.Empty);
        }

        /// <summary>
        /// Fills missing nodes in a policy tree based on the difference between occurrences
        /// of child nodes (<c>PolicyNode.SimilarNodeTest</c>).
        /// For every level of every tree, each incomplete node is compared with each complete node
        /// on the same level and merged with every one that passes.
        /// </summary>
        /// <param name="trees">input partial trees; modified in place</param>
        /// <param name="diff">max difference value</param>
        /// <param name="mergesFile">output merges file; null or empty to skip</param>
        /// <param name="stopwatch">
        /// stopwatch started by the caller; it is stopped once merging finishes and its elapsed
        /// milliseconds are written as LearningTime. If null, LearningTime is written as 0.
        /// </param>
        /// <returns>filled in trees</returns>
        public List<PolicyNode> FillMissingNodesDifference(List<PolicyNode> trees, double diff, string mergesFile, Stopwatch stopwatch)
        {
            int incompleteNodes;
            int differenceTests;
            int mergeCount;
            MergeSimilarNodes(trees, (incomplete, complete) => incomplete.SimilarNodeTest(complete, diff),
                out incompleteNodes, out differenceTests, out mergeCount);

            // Stop only after merging so the caller's timing covers the whole learning run.
            double learningTime = 0;
            if (stopwatch != null)
            {
                stopwatch.Stop();
                learningTime = stopwatch.ElapsedMilliseconds;
            }

            if (!string.IsNullOrEmpty(mergesFile))
            {
                using (StreamWriter writer = new StreamWriter(mergesFile))
                {
                    // Header ends with a comma so its field count matches the data row, which also ends with one.
                    writer.WriteLine("NumberOfTrees,IncompleteNodes,DiffTests,Merges,IncompleteNodes/Tree,Merges/Tree,LearningTime,");
                    writer.WriteLine(FormatMergeRow(trees.Count, incompleteNodes, differenceTests, mergeCount)
                        + "," + learningTime.ToString(CsvFormat) + ",");
                }
            }

            return trees;
        }

        /// <summary>
        /// Fills missing nodes in a policy tree based on the difference between occurrences
        /// of child nodes, without writing a merges file or timing the run.
        /// </summary>
        /// <param name="trees">input partial trees; modified in place</param>
        /// <param name="diff">max difference value</param>
        /// <returns>filled in trees</returns>
        public List<PolicyNode> FillMissingNodesDifference(List<PolicyNode> trees, double diff)
        {
            return FillMissingNodesDifference(trees, diff, string.Empty, new Stopwatch());
        }

        /// <summary>
        /// Shared merge loop for the FillMissingNodes* methods: for levels 1 .. GetTreeLength()-1 of
        /// each tree, compares every incomplete node with every complete node on that level and
        /// merges the pair whenever <paramref name="isSimilar"/> returns true.
        /// </summary>
        /// <param name="trees">trees to process; nodes are merged in place</param>
        /// <param name="isSimilar">similarity test (incompleteNode, completeNode) => should merge</param>
        /// <param name="incompleteNodes">number of incomplete nodes visited</param>
        /// <param name="tests">number of similarity tests run</param>
        /// <param name="merges">number of merges performed</param>
        private static void MergeSimilarNodes(List<PolicyNode> trees, Func<PolicyNode, PolicyNode, bool> isSimilar,
            out int incompleteNodes, out int tests, out int merges)
        {
            incompleteNodes = 0;
            tests = 0;
            merges = 0;

            foreach (PolicyNode rootNode in trees)
            {
                // Level 0 is skipped. NOTE(review): presumably level 0 is the root itself — confirm.
                for (int level = 1; level < rootNode.GetTreeLength(); level++)
                {
                    List<PolicyNode> levelCompleteNodes = new List<PolicyNode>();
                    List<PolicyNode> levelIncompleteNodes = new List<PolicyNode>();

                    //Divide nodes into two categories of complete and incomplete nodes
                    foreach (PolicyNode levelNode in rootNode.GetNodesForLevel(level))
                    {
                        if (levelNode.IsComplete())
                        {
                            levelCompleteNodes.Add(levelNode);
                        }
                        else
                        {
                            levelIncompleteNodes.Add(levelNode);
                        }
                    }

                    //Test each incomplete node against every complete node on the same level
                    foreach (PolicyNode incompleteNode in levelIncompleteNodes)
                    {
                        incompleteNodes++;

                        foreach (PolicyNode completeNode in levelCompleteNodes)
                        {
                            tests++;

                            if (isSimilar(incompleteNode, completeNode))
                            {
                                incompleteNode.MergeWithNode(completeNode);
                                merges++;
                            }
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Formats the six statistics columns shared by both merges files (no trailing comma),
        /// using the invariant culture so decimal separators never clash with the CSV delimiter.
        /// Per-tree ratios are NaN when there are no trees.
        /// </summary>
        private static string FormatMergeRow(int treeCount, int incompleteNodes, int tests, int merges)
        {
            return string.Join(",", new string[]
            {
                treeCount.ToString(CsvFormat),
                incompleteNodes.ToString(CsvFormat),
                tests.ToString(CsvFormat),
                merges.ToString(CsvFormat),
                ((double)incompleteNodes / (double)treeCount).ToString(CsvFormat),
                ((double)merges / (double)treeCount).ToString(CsvFormat)
            });
        }

        /// <summary>
        /// Learns policy trees without a known horizon.
        /// For each row i (from the second row on) of each file, histories of length k = 1 .. maxHorizon-1
        /// ending at row i are compared with all length-k windows from every input file. The longest
        /// k is kept if more than one window shares the history and all of them agree on the action
        /// (per <c>HasSameHistory</c> / <c>IsHistoryCompatible</c>). That history is then merged into the trees.
        /// </summary>
        /// <param name="inputFiles">input files to learn from</param>
        /// <param name="obsCol">observation column name</param>
        /// <param name="actionCol">action column name</param>
        /// <param name="observations">list of potential observations</param>
        /// <param name="maxHorizon">maximum horizon length (exclusive upper bound on k)</param>
        /// <returns>learned policy trees</returns>
        public List<PolicyNode> LearnFromUnknownHorizonHistories(List<string> inputFiles, string obsCol, string actionCol, List<string> observations, int maxHorizon)
        {
            //Get histories to compare with
            Dictionary<int, List<PolicyNode>> compareHistories = BuildDifferentLengthHistoriesForFiles(inputFiles, obsCol,
                actionCol, observations, maxHorizon);

            List<PolicyNode> trees = new List<PolicyNode>();
            
            foreach (string inputFile in inputFiles)
            {
                CsvParser parser = new CsvParser();
                List<string> headings = parser.GetColumnHeadings(inputFile);
                List<List<string>> rows = parser.ParseRows(inputFile, true);

                int obsIndex = headings.IndexOf(obsCol);
                int actionIndex = headings.IndexOf(actionCol);

                //for each row in data try a different history length comparing with the comparable histories
                for (int i = 1; i < rows.Count; i++)
                {
                    int maxFound = 0;
                    List<string> bestHistory = null;

                    for (int k = 1; k < maxHorizon; k++)
                    {
                        // History of length k ending at row i, built in chronological order (rows i-k .. i)
                        // so it lines up with the forward windows in compareHistories[k].
                        // Near the start of the file (i < k) the branch is truncated to rows 0 .. i.
                        List<List<string>> currentBranch = new List<List<string>>();
                        for (int j = Math.Max(0, i - k); j <= i; j++)
                        {
                            currentBranch.Add(rows[j]);
                        }

                        List<string> hist = BranchToHistory(currentBranch, obsIndex, actionIndex);

                        //compare the history with all other histories in the dictionary of the same length
                        int sameHistories = 0;
                        int sameActions = 0;

                        foreach (PolicyNode root in compareHistories[k])
                        {
                            if (root.HasSameHistory(hist.ToList()))
                            {
                                sameHistories++;
                                if (root.IsHistoryCompatible(hist.ToList(), string.Empty))
                                {
                                    sameActions++;
                                }
                            }
                        }

                        //If all same histories lead to same action then history must always lead to action
                        //more than 1 same history is required because the history itself is always
                        //among the comparison windows
                        if (sameHistories > 1 && sameHistories == sameActions)
                        {
                            maxFound = k;
                            bestHistory = hist;
                        }
                    }

                    //If a maximum length was found, add / merge the history of that length into the roots
                    if (maxFound > 0)
                    {
                        trees = AddBranchToPolicyTrees(trees, bestHistory, observations, string.Empty);
                    }
                }
            }

            return trees;
        }

        /// <summary>
        /// Builds histories of each length up to the maximum for every file and concatenates them per length.
        /// </summary>
        /// <param name="inputFiles">list of input files</param>
        /// <param name="obsCol">observation column name</param>
        /// <param name="actionCol">action column name</param>
        /// <param name="observations">list of potential observations</param>
        /// <param name="maxHorizon">max possible horizon length</param>
        /// <returns>dictionary from history length k (1 .. maxHorizon-1) to the length-k histories of all files</returns>
        private Dictionary<int, List<PolicyNode>> BuildDifferentLengthHistoriesForFiles(List<string> inputFiles, string obsCol, string actionCol, List<string> observations, int maxHorizon)
        {
            Dictionary<int, List<PolicyNode>> histories = new Dictionary<int, List<PolicyNode>>();

            foreach (string inputFile in inputFiles)
            {
                Dictionary<int, List<PolicyNode>> fileHistories = BuildDifferentLengthHistoriesForFile(inputFile, obsCol,
                    actionCol, observations, maxHorizon);

                foreach (KeyValuePair<int, List<PolicyNode>> entry in fileHistories)
                {
                    List<PolicyNode> existing;
                    if (histories.TryGetValue(entry.Key, out existing))
                    {
                        existing.AddRange(entry.Value);
                    }
                    else
                    {
                        histories.Add(entry.Key, entry.Value);
                    }
                }
            }

            return histories;
        }

        /// <summary>
        /// Builds histories of differing lengths up to the maximum length; a history is a branch of a tree.
        /// Key k holds one single-branch tree per window of k+1 consecutive rows, sliding over the whole file.
        /// </summary>
        /// <param name="inputFile">input file to build histories from</param>
        /// <param name="obsCol">observation column name</param>
        /// <param name="actionCol">action column name</param>
        /// <param name="observations">list of potential observations</param>
        /// <param name="maxHorizon">max horizon length</param>
        /// <returns>dictionary from history length k (1 .. maxHorizon-1) to histories of that length</returns>
        private Dictionary<int, List<PolicyNode>> BuildDifferentLengthHistoriesForFile(string inputFile, string obsCol, string actionCol, List<string> observations, int maxHorizon)
        {
            Dictionary<int, List<PolicyNode>> histories = new Dictionary<int, List<PolicyNode>>();

            CsvParser parser = new CsvParser();
            List<string> headings = parser.GetColumnHeadings(inputFile);
            List<List<string>> rows = parser.ParseRows(inputFile, true);

            int obsIndex = headings.IndexOf(obsCol);
            int actionIndex = headings.IndexOf(actionCol);

            for (int windowSize = 2; windowSize <= maxHorizon; windowSize++)
            {
                List<PolicyNode> horizonHistories = new List<PolicyNode>();

                // "<=" so the final window, ending on the last row, is included.
                for (int start = 0; start <= (rows.Count - windowSize); start++)
                {
                    List<List<string>> currentBranch = rows.GetRange(start, windowSize);

                    //Convert the branch rows to a policy node branch
                    List<string> hist = BranchToHistory(currentBranch, obsIndex, actionIndex);
                    horizonHistories.Add(new PolicyNode(observations, hist, string.Empty));
                }

                histories.Add(windowSize - 1, horizonHistories);
            }
            
            return histories;
        }
    }
}
